import torch
import numpy as np
from tqdm import tqdm
from ddpm import DDPMSampler

WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

def generate(prompt, uncond_prompt, input_image=None,
             strength=0.8, do_cfg=True, cfg_scale=7.5, sampler_name="ddpm",
             inference_steps=50, models={}, seed=None,
             device=None, idle_device=None, tokenizer=None):
    with torch.no_grad():
        if not (0 < strength <= 1):
            raise ValueError("strength must be between 0 and 1")
        
        # Offload helper: moves a module to the idle device (e.g. the CPU) once
        # it is no longer needed, freeing memory on the main device.
        if idle_device:
            to_idle = lambda x: x.to(idle_device)
        else:
            to_idle = lambda x: x
        generator = torch.Generator(device=device)
        if seed is None:
            generator.seed()
        else:
            generator.manual_seed(seed)
        clip = models["clip"]
        clip.to(device)
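        # Encode the prompt(s) into a context tensor with the CLIP text encoder.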
        if do_cfg:
            # Classifier-free guidance: encode both the prompt and the negative
            # (unconditional) prompt; the two predictions are blended later with cfg_scale.
            conditional_tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
            conditional_tokens = torch.tensor(conditional_tokens, dtype=torch.long, device=device)
            conditional_context = clip(conditional_tokens)

            unconditional_tokens = tokenizer.batch_encode_plus([uncond_prompt], padding="max_length", max_length=77).input_ids
            unconditional_tokens = torch.tensor(unconditional_tokens, dtype=torch.long, device=device)
            unconditional_context = clip(unconditional_tokens)

            context = torch.cat([conditional_context, unconditional_context])
        else:
            tokens = tokenizer.batch_encode_plus([prompt], padding="max_length", max_length=77).input_ids
            tokens = torch.tensor(tokens, dtype=torch.long, device=device)
            context = clip(tokens)
        to_idle(clip)

        if sampler_name == "ddpm":
            sampler = DDPMSampler(generator)
            sampler.set_inference_timesteps(inference_steps)
        else:
            raise ValueError(f"unknown sampler: {sampler_name}")
        
        # Latents are 8x downsampled from the image: (batch, 4, HEIGHT/8, WIDTH/8).
        latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
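        # Image-to-image: encode the input image to latents, then partially noise them.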
        if input_image:
            encoder = models["encoder"]
            encoder.to(device)
            # Preprocess: resize to the model resolution, rescale pixels from
            # [0, 255] to [-1, 1], add a batch dimension, and reorder to (B, C, H, W).
            input_image_tensor = input_image.resize((WIDTH, HEIGHT))
            input_image_tensor = np.array(input_image_tensor)
            input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
            input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
            input_image_tensor = input_image_tensor.unsqueeze(0)
            input_image_tensor = input_image_tensor.permute(0, 3, 1, 2)

            encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
            latents = encoder(input_image_tensor, encoder_noise)
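            # strength sets how much noise is added to the encoded latents:
            # higher strength means the output strays further from the input image.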
            sampler.set_strength(strength=strength)
            latents = sampler.add_noise(latents, sampler.timesteps[0])
            to_idle(encoder)
        else:
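            # Text-to-image: start from pure Gaussian noise in latent space.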
            latents = torch.randn(latents_shape, generator=generator, device=device)

        diffusion = models["diffusion"]
        diffusion.to(device)
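        # Reverse diffusion: iteratively denoise the latents over the sampler's timesteps.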
        timesteps = tqdm(sampler.timesteps)
        for step in timesteps:
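            # Sinusoidal embedding of the current timestep, shape (1, 320).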
            time_embedding = get_time_embedding(step).to(device)
            model_input = latents
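            # With CFG, duplicate the latents so the conditional and unconditional
            # contexts are denoised in a single forward pass.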
            if do_cfg:
                model_input = model_input.repeat(2, 1, 1, 1)
            model_output = diffusion(model_input, context, time_embedding)

            if do_cfg:
                # Classifier-free guidance: extrapolate from the unconditional
                # prediction toward the conditional one by cfg_scale.
                output_cond, output_uncond = model_output.chunk(2)
                model_output = cfg_scale * (output_cond - output_uncond) + output_uncond
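            # One reverse-diffusion step: remove the noise predicted for this timestep.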
            latents = sampler.step(step, latents, model_output)
        to_idle(diffusion)

        decoder = models["decoder"]
        decoder.to(device)
        images = decoder(latents)
        to_idle(decoder)

        # Map pixel values back from [-1, 1] to [0, 255], reorder to (B, H, W, C),
        # and convert to a uint8 numpy array.
        images = rescale(images, (-1, 1), (0, 255), clamp=True)
        images = images.permute(0, 2, 3, 1)
        images = images.to("cpu", torch.uint8).numpy()
        return images[0]

def rescale(x, old_range, new_range, clamp=False):
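    """Linearly remap x from old_range to new_range, optionally clamping to the new bounds."""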
    old_min, old_max = old_range
    new_min, new_max = new_range
    x -= old_min
    x *= (new_max - new_min) / (old_max - old_min)
    x += new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x

def get_time_embedding(time_step):
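    """Sinusoidal timestep embedding: 160 (cos, sin) frequency pairs giving a (1, 320) tensor."""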
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
    x = torch.tensor([time_step], dtype=torch.float32)[:, None] * freqs[None]
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
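
# Minimal sanity-check sketch for the two helpers above (runs without any model
# weights); the expected values follow directly from the definitions.
if __name__ == "__main__":
    emb = get_time_embedding(999)
    assert emb.shape == (1, 320)
    x = rescale(torch.tensor([0.0, 127.5, 255.0]), (0, 255), (-1, 1))
    print(x)  # tensor([-1., 0., 1.])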